import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed

import librosa


def convert_to_mkv_filename(old_filename):
    """Map a clip file name to the name of its downloaded ``.mkv`` source.

    The last four characters of *old_filename* (an extension such as
    ``.mp4``) are dropped. The final six characters of what remains are read
    as an integer start second. The character just before them (presumably
    a separator) is discarded. Everything earlier is the video id.

    Returns ``v<id>_<start>_<start + 10>_out.mkv``. The start second is
    printed without leading zeros.
    """
    stem = old_filename[:-4]
    start = int(stem[-6:])
    return f"v{stem[:-7]}_{start}_{start + 10}_out.mkv"


# Paths whose conversion target already existed. Nothing in this file
# currently appends to it (the check that did so is disabled), so it stays empty.
exist_list = list()
# Source .mkv paths for which the ffmpeg conversion exited with a non-zero status.
err_list = list()


def convert_mkv_to_mp4(old_filename):
    """Convert one downloaded ``.mkv`` clip to its target ``.mp4``, then delete the ``.mkv``.

    Args:
        old_filename: ``(directory, name)`` pair. ``name`` is the target clip
            file name. ``convert_to_mkv_filename(name)`` gives the source
            ``.mkv`` name, which is looked up inside ``directory``. The output
            is written to ``os.path.join(os.path.dirname(directory),
            os.path.basename(name))``.

    Returns:
        None. If the source ``.mkv`` does not exist, the call is a no-op.

    Side effects:
        Runs ffmpeg, copying the video stream and re-encoding the audio to AAC
        at 16 kbit/s. A non-zero ffmpeg exit status appends the ``.mkv`` path
        to the module-level ``err_list`` and prints ffmpeg's stderr.
        ``OSError`` (e.g. ``ffmpeg`` not on ``PATH``, or the delete failing)
        is printed rather than propagated.
    """
    src_dir, clip_name = old_filename[0], old_filename[1]
    mkv_path = os.path.join(src_dir, convert_to_mkv_filename(clip_name))
    # dirname of the directory itself: the .mp4 lands one level above src_dir
    # (or inside it when src_dir ends with a path separator).
    mp4_path = os.path.join(os.path.dirname(src_dir), os.path.basename(clip_name))

    if not os.path.exists(mkv_path):
        return

    # Normalise to forward slashes for ffmpeg (a no-op on POSIX).
    mkv_path = mkv_path.replace(os.sep, '/')
    mp4_path = mp4_path.replace(os.sep, '/')

    command = [
        'ffmpeg',
        '-loglevel', 'quiet',
        '-y',
        '-i', mkv_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-b:a', '16k',
        mp4_path,
    ]
    try:
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            err_list.append(mkv_path)
            print(f"Failed to convert {mkv_path}: {result.stderr.decode(errors='replace')}")
        # NOTE(review): the .mkv is deleted even when ffmpeg failed; only err_list
        # keeps a record of it. Confirm that discarding failed sources is intended.
        os.remove(mkv_path)
    except OSError as e:
        # subprocess.run is not called with check=True, so CalledProcessError can
        # never occur here. OSError is what actually escapes: a missing ffmpeg
        # binary or a failed delete.
        print(f"Failed to convert {mkv_path}: {e}")


import tqdm


def convert_files(file_list, num_threads=4):
    """Run ``convert_mkv_to_mp4`` over *file_list* on a thread pool.

    Args:
        file_list: Iterable of ``(directory, name)`` pairs accepted by
            ``convert_mkv_to_mp4``.
        num_threads: Number of worker threads.

    Any exception raised by a worker is printed together with its input.
    After all jobs finish, prints ``len(err_list)`` and then ``len(exist_list)``.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {executor.submit(convert_mkv_to_mp4, item): item for item in file_list}
        # Advance the bar as jobs *finish*. Wrapping file_list only measured the
        # near-instant submission loop, not conversion progress.
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            try:
                future.result()
            except Exception as e:
                print(f"Error processing file {futures[future]}: {e}")
    print(len(err_list))
    print(len(exist_list))


import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed


def extract_audio_from_mp4(mp4_filename, wav_filename):
    """Write the audio track of *mp4_filename* to *wav_filename*.

    The output is 16-bit PCM, 16 kHz, 2 channels, with the video dropped.
    Nothing happens if *wav_filename* already exists (ffmpeg is also run
    with ``-n``, so it never overwrites). A non-zero ffmpeg exit is printed
    rather than raised.
    """
    if os.path.exists(wav_filename):
        return

    ffmpeg_args = ['ffmpeg', "-loglevel", "quiet", '-n', '-i', mp4_filename]
    ffmpeg_args += ['-vn', '-acodec', 'pcm_s16le', '-ar', '16000', '-ac', '2']
    ffmpeg_args.append(wav_filename)

    try:
        subprocess.run(ffmpeg_args, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Failed to extract audio from {mp4_filename}: {err}")


def pad_wav_to_10_seconds(wav_filename):
    """Discard a clip whose audio or preprocessed video is too short.

    Despite the name, nothing is padded. The function measures:

    * the audio duration of *wav_filename* (via ``get_audio_duration``), and
    * the frame count of the sibling ``*_prep.mp4`` video (via ``ffprobe``).

    If the audio is shorter than 9 seconds or the video has fewer than 45
    frames, it prints both measurements and deletes the ``.wav``, the
    ``.mp4`` and the ``_prep.mp4`` for the clip. Files that are already
    missing are skipped.

    Sibling paths use ``str.replace('.wav', ...)``, mirroring how
    ``process_mp4_file`` derives the ``.wav`` path from the ``.mp4``.

    If either probe yields no parseable number (missing or unreadable file,
    no video stream), an error is printed and nothing is deleted.
    """
    prep_video = wav_filename.replace('.wav', '_prep.mp4')
    # NOTE(review): -count_frames populates nb_read_frames, but nb_frames (the
    # container header value) is what is read here, so the decode pass is
    # wasted. Confirm which count is intended.
    command = [
        'ffprobe',
        '-v', 'error',
        '-count_frames',
        '-select_streams', 'v:0',
        '-show_entries', 'stream=nb_frames',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        prep_video,
    ]
    try:
        duration = get_audio_duration(wav_filename)
        result = subprocess.run(command, capture_output=True, text=True)
        total_frames = int(result.stdout.strip())
    except ValueError as e:
        # Neither probe uses check=True. When ffprobe cannot read a file it
        # prints nothing (or "N/A"), and parsing that output is the real
        # failure signal.
        print(f"Failed to pad {wav_filename}: {e}")
        return

    if duration < 9 or total_frames < 45:
        print(duration, total_frames, wav_filename)
        for path in (wav_filename, wav_filename.replace('.wav', '.mp4'), prep_video):
            try:
                os.remove(path)
            except FileNotFoundError:
                # Best effort: one file already being gone must not stop the
                # remaining ones from being removed.
                pass


def get_audio_duration(audio_filename):
    """Return the container duration of *audio_filename* in seconds.

    The value comes from ``ffprobe``'s ``format=duration`` entry; ffprobe's
    stderr is captured and discarded.

    Raises:
        ValueError: ffprobe printed nothing numeric (e.g. the file is unreadable).
        FileNotFoundError: the ``ffprobe`` executable could not be found.
    """
    probe = subprocess.run(
        [
            'ffprobe',
            '-v', 'error',
            '-show_entries', 'format=duration',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            audio_filename,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return float(probe.stdout.strip())


def process_mp4_file(mp4_file):
    """Extract audio from one ``.mp4`` and apply the short-clip filter to it.

    The ``.wav`` path is *mp4_file* with every ``'.mp4'`` replaced by
    ``'.wav'``. Any exception is printed instead of raised, so one bad clip
    does not abort a batch.
    """
    try:
        wav_path = mp4_file.replace('.mp4', '.wav')
        extract_audio_from_mp4(mp4_file, wav_path)
        pad_wav_to_10_seconds(wav_path)
    except Exception as exc:
        print(f"Error processing file {mp4_file}: {exc}")


def process_mp4_files(mp4_files, num_threads=60):
    """Run ``process_mp4_file`` on every path in *mp4_files* using a thread pool.

    Args:
        mp4_files: Iterable of ``.mp4`` paths.
        num_threads: Number of worker threads.

    ``process_mp4_file`` prints its own errors. ``future.result()`` is still
    called so that anything it does not catch surfaces here.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(process_mp4_file, mp4_file) for mp4_file in mp4_files]
        # Track completion, not submission (submission finishes almost instantly).
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            future.result()


def preprocess_video(input_file):
    """Create a normalised ``*_prep.mp4`` copy of *input_file*.

    The copy is:

    * resampled to 5 fps;
    * scaled to fit 256x192, centred on a white pad, and cropped to exactly
      256x192;
    * re-encoded with libx264 (preset medium, CRF 23);
    * stripped of audio.

    Frame-count handling afterwards:

    * 45-49 frames: the last frame is repeated until there are 50
      (10 s at 5 fps).
    * Fewer than 45 frames: only a warning is printed.

    Does nothing if the ``_prep.mp4`` already exists. If any ffmpeg call
    exits non-zero, *input_file* is deleted and the error is printed.
    """
    output_file = input_file.replace('.mp4', '_prep.mp4')
    if os.path.exists(output_file):
        return
    try:
        command = [
            'ffmpeg',
            '-i', input_file,
            '-vf',
            'fps=5,scale=256:192:force_original_aspect_ratio=decrease,pad=256:192:-1:-1:color=white,crop=256:192',
            '-c:v', 'libx264',
            '-preset', 'medium',
            '-crf', '23',
            '-an',
            '-n',
            output_file
        ]
        subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        command = [
            'ffprobe',
            '-v', 'error',
            '-count_frames',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=nb_frames',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            output_file
        ]
        result = subprocess.run(command, capture_output=True, text=True)
        total_frames = int(result.stdout.strip())

        if total_frames < 45:
            print(f"Warning: {output_file} has less than 45 frames")

        elif total_frames < 50:
            padding_frames = 50 - total_frames
            # ffmpeg cannot edit a file in place: it refuses to run when input
            # and output are the same path, and that failure used to land in
            # the except below and delete the source clip. Write the padded
            # video to a temporary file, then swap it in.
            root, ext = os.path.splitext(output_file)
            padded_file = f"{root}_pad{ext}"
            command = [
                'ffmpeg',
                '-i', output_file,
                # tpad clone mode appends `stop` copies of the last frame.
                '-vf', f"tpad=stop_mode=clone:stop={padding_frames}",
                '-c:v', 'libx264',
                '-preset', 'medium',
                '-crf', '23',
                '-an',
                '-y',
                padded_file
            ]
            subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            os.replace(padded_file, output_file)
            print(f"Padded {output_file} to 50 frames")

    except subprocess.CalledProcessError as e:
        # NOTE(review): a partially written _prep.mp4 from a failed first pass is
        # left behind, and later calls skip this clip because it exists.
        os.remove(input_file)
        print(f"Failed to preprocess {input_file}: {e}")


def preprocess_videos(video_files, num_threads=4):
    """Run ``preprocess_video`` on each path in *video_files* using a thread pool.

    Worker exceptions are re-raised here, in completion order.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = []
        for path in video_files:
            pending.append(pool.submit(preprocess_video, path))
        for done in as_completed(pending):
            done.result()


import argparse
import os
import csv
import numpy as np
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


class GetAudioVideoDataset(Dataset):
    """Audio clips listed in per-split CSV files, returned as fixed-length waveforms.

    Class names are read from column 0 of ``args.csv_path + 'stat.csv'``.

    Clips are read from ``args.csv_path + args.mode + '.csv'``:

    * column 0 is a file name whose last three characters are replaced by
      ``mp4``;
    * column 1 is the class label.

    A clip is kept when its label is a known class and
    ``args.data_path + args.mode + '/<label dir>/<name>.mp4'`` exists.
    ``<label dir>`` is the label with spaces turned into ``_``, commas
    removed, and dots turned into ``_``.

    Construction runs ``process_mp4_files`` on the kept clips. That step
    extracts the ``.wav`` files and applies a short-clip filter that may
    delete files. Afterwards only clips whose ``.wav`` exists are kept.

    Each item is ``(samples, label_index)``:

    * ``samples`` is the clip's ``.wav`` loaded at 16 kHz (mono),
      zero-padded or truncated to 160000 samples;
    * ``label_index`` is the label's position in ``self.classes``, the
      sorted class list.
    """

    def __init__(self, args, mode='train', transforms=None):
        # NOTE(review): every path below is built from ``args.mode``; the ``mode``
        # argument is only stored as ``self.mode``. Confirm callers keep them equal.
        with open(args.csv_path + 'stat.csv') as f:
            classes = [row[0] for row in csv.reader(f)]
        known_classes = set(classes)  # O(1) membership while scanning the split CSV

        self.video_files = []
        self.labels = []  # class label for each entry of self.video_files
        split_dir = args.data_path + args.mode + '/'
        with open(args.csv_path + args.mode + '.csv') as f:
            for item in csv.reader(f):
                label = item[1]
                label_dir = label.replace(' ', '_').replace(',', '').replace('.', '_')
                path = os.path.join(split_dir + label_dir + '/', item[0][:-3] + 'mp4')
                if label in known_classes and os.path.exists(path):
                    self.video_files.append(path)
                    self.labels.append(label)

        self.mode = mode
        self.transforms = transforms
        self.classes = sorted(classes)
        # Label -> index, keeping the first position just as list.index would.
        self._class_to_idx = {}
        for i, name in enumerate(self.classes):
            self._class_to_idx.setdefault(name, i)

        # initialize audio transform
        self._init_atransform()
        print('# of audio files = %d ' % len(self.video_files))
        print('# of classes = %d' % len(self.classes))
        process_mp4_files(self.video_files)

        # process_mp4_files deletes clips that fail its length filter, and
        # extraction can fail. Drop entries without a .wav now instead of
        # failing later in __getitem__.
        kept = [
            (path, label)
            for path, label in zip(self.video_files, self.labels)
            if os.path.exists(path.replace('.mp4', '.wav'))
        ]
        self.video_files = [path for path, _ in kept]
        self.labels = [label for _, label in kept]
        print('# of usable audio files = %d' % len(self.video_files))

    def _init_atransform(self):
        """Build ``self.aid_transform`` (a ToTensor-only Compose); this class does not use it."""
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        """Return ``(samples, label_index)`` for clip *idx*."""
        mp4_path = self.video_files[idx]
        wav_path = mp4_path.replace('.mp4', '.wav')
        # The .wav was extracted at 16 kHz. Loading at that rate makes 160000
        # samples equal 10 s; librosa.load otherwise resamples to 22050 Hz by
        # default. librosa downmixes the stereo file to mono by default.
        samples, _ = librosa.load(wav_path, sr=16000)
        if len(samples) < 157000:
            # Flag clips noticeably shorter than 10 s before zero-padding them.
            print(mp4_path, len(samples))
        samples = np.pad(samples, (0, max(0, 160000 - len(samples))))[:160000]
        return samples, self._class_to_idx[self.labels[idx]]